package cn.kennylee.codehub.mongodb.das.extension;

import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.lang.Assert;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import cn.kennylee.codehub.mongodb.das.entity.MongoDbEo;
import com.mongodb.client.result.UpdateResult;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.mongodb.core.BulkOperations;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static cn.kennylee.codehub.mongodb.das.utils.MongoDbDasHelper.*;


/**
 * <p>抽象基础MongoDB数据访问层，提供最基础的数据操作方法</p>
 * <p>Created on 2024/12/7.</p>
 *
 * @author kennylee
 * @since 0.0.1
 */
@Slf4j
public abstract class ComMongoDas<T extends MongoDbEo<P>, P extends Serializable> {

    private MongoTemplate mongoTemplate;

    /**
     * 保存对象
     *
     * @param objectToSave 要保存的对象
     * @return 保存后的对象
     */
    public T save(T objectToSave) {
        if (log.isDebugEnabled()) {
            log.debug("保存对象：{}", JSONUtil.toJsonStr(objectToSave));
        }
        return mongoTemplate.save(objectToSave);
    }

    /**
     * <p>插入（批量），错误自动回滚。</p>
     * <p>由于MongoDB非事务性，不支持回滚，但为了保证批量操作的数据一致性，发生异常时会进行手动回滚数据</p>
     *
     * @param entityList 实体对象集合
     * @return 插入的条目数
     */
    public int saveBatch(Collection<T> entityList) {
        return saveBatch(entityList, true);
    }

    /**
     * 插入（批量），由于MongoDB非事务性，默认不支持回滚；此方法扩展支持异常回滚，回滚时会删除已插入的数据
     *
     * @param entityList      实体对象集合
     * @param rollbackOnError 异常时是否进行回滚
     * @return 插入的条目数
     */
    public int saveBatch(Collection<T> entityList, boolean rollbackOnError) {
        if (CollUtil.isEmpty(entityList)) {
            log.warn("保存对象为空");
            return 0;
        }

        Collection<T> inserts = new ArrayList<>(entityList.size());

        // 批量插入操作
        try {
            inserts.addAll(mongoTemplate.insertAll(entityList));
        } catch (Exception e) {
            log.error(StrUtil.format("批量插入操作发生异常，异常信息: {}", e.getMessage()), e);
            if (rollbackOnError) {
                // 手动回滚
                try {
                    long count = this.removeByDocIds(entityList);
                    log.warn("批量插入操作发生异常，已回滚{}条记录", count);
                } catch (Exception ex) {
                    log.error(StrUtil.format("批量插入操作，异常回滚操作时也发生异常，异常信息: {}", ex.getMessage()), ex);
                }
            }
            // 依然抛出异常
            throw e;
        }

        return inserts.size();
    }

    /**
     * <p>插入（批量），非事务性，不支持回滚，忽略错误</p>
     * <p>入参列表需要先定义主键值</p>
     *
     * @param entityList 实体对象集合
     * @return 插入的条目数
     */
    public int saveBatchIgnoreErrors(Collection<T> entityList) {
        if (CollUtil.isEmpty(entityList)) {
            log.warn("保存对象为空");
            return 0;
        }

        entityList.parallelStream().forEach(entity -> {
            Assert.notNull(entity, "保存对象不能为空");
            Assert.notNull(entity.getId(), "保存对象主键不能为空");
        });

        BulkOperations bulkOperations = mongoTemplate.bulkOps(BulkOperations.BulkMode.ORDERED, getEnityClass());
        bulkOperations.insert(List.copyOf(entityList));
        try {
            // 执行批量操作
            var bulkWriteResult = bulkOperations.execute();
            return bulkWriteResult.getInsertedCount();
        } catch (Exception e) {
            log.error(StrUtil.format("批量操作发生失败，忽略错误继续保存模式，异常信息: {}", e.getMessage()), e);
            return (int) countByIds(getIds(entityList));
        }
    }

    /**
     * <p>插入（批量），分批次插入</p>
     * <p>由于MongoDB非事务性，默认不支持回滚；此方法扩展支持异常回滚，回滚时会删除已插入的数据</p>
     *
     * @param entityList 实体对象集合
     * @param batchSize  插入批次数量; 小于1时，代表不分批
     * @return 插入的条目数
     */
    public int saveBatch(Collection<T> entityList, int batchSize) {
        return saveBatch(entityList, batchSize, false, true);
    }

    /**
     * <p>插入（批量），分批次插入，忽略错误</p>
     * <p>非事务性，不支持回滚</p>
     *
     * @param entityList 实体对象集合
     * @param batchSize  插入批次数量; 小于1时，代表不分批
     * @return 插入的条目数
     */
    public int saveBatchIgnoreErrors(Collection<T> entityList, int batchSize) {
        return saveBatch(entityList, batchSize, true, false);
    }

    /**
     * 分片进行批量插入，非事务性，不支持回滚
     *
     * @param entityList      实体对象集合
     * @param batchSize       批次数量; 小于1时，代表不分批
     * @param ignoreErrors    是否忽略错误继续
     * @param rollbackOnError 异常时是否进行回滚，仅ignoreErrors为false时生效
     * @return 插入的条目数
     */
    protected int saveBatch(Collection<T> entityList, int batchSize, boolean ignoreErrors, boolean rollbackOnError) {

        if (CollUtil.isEmpty(entityList)) {
            log.warn("保存对象为空");
            return 0;
        }

        AtomicInteger count = new AtomicInteger();

        List<List<T>> splitEntityList = batchSize < 1 ? List.of(List.copyOf(entityList)) :
            CollUtil.split(entityList, batchSize);

        // 记录已插入的主键，线程安全
        final Set<P> ids = new CopyOnWriteArraySet<>();

        // 回滚的异常
        AtomicReference<RuntimeException> rollbackException = new AtomicReference<>();

        splitEntityList.parallelStream().forEach(entities -> {
            if (Thread.currentThread().isInterrupted()) {
                log.debug("saveBatch线程已中断，跳过后续操作");
                return;
            }

            int splitInsertCount;
            if (ignoreErrors) {
                splitInsertCount = this.saveBatchIgnoreErrors(entities);
            } else {
                // 分片操作，不单独回滚
                try {
                    splitInsertCount = this.saveBatch(entities, false);
                } catch (RuntimeException e) {
                    splitInsertCount = 0;
                    // 需要回滚
                    if (rollbackOnError) {
                        // 记录异常
                        rollbackException.set(e);
                        log.error(StrUtil.format("批量插入操作发生异常，异常信息: {}", e.getMessage()), e);
                        // 处理中断异常，重新设置中断标志
                        Thread.currentThread().interrupt();
                    } else {
                        // 直接抛出异常
                        throw e;
                    }
                }
            }
            log.debug("分片插入了{}条记录", splitInsertCount);
            count.addAndGet(splitInsertCount);

            // 记录已插入的主键
            ids.addAll(getIds(entities));
        });

        // 如果不忽略错误，且发生异常，需要回滚
        if (Objects.nonNull(rollbackException.get())) {
            // 手动回滚
            try {
                long removeCount = this.removeByIds(ids);
                log.warn("批量插入操作发生异常，已回滚{}条记录", removeCount);
            } catch (Exception ex) {
                log.error(StrUtil.format("批量插入操作，异常回滚操作时也发生异常，异常信息: {}", ex.getMessage()), ex);
            }
            // 依然抛出异常
            throw rollbackException.get();
        }

        if (log.isDebugEnabled()) {
            log.debug("总批量插入了{}条记录", count.get());
        }

        return count.get();
    }

    /**
     * 根据 ID 删除
     *
     * @param id 主键ID
     */
    public boolean removeById(P id) {
        Query query = new Query();
        query.addCriteria(Criteria.where(MONGO_DB_ID_NAME).is(id));
        long deletedCount = mongoTemplate.remove(query, getEnityClass()).getDeletedCount();
        log.debug("删除了{}条记录", deletedCount);
        return deletedCount > 0;
    }

    /**
     * 根据 IDs 删除
     *
     * @param ids 主键ID列表
     * @return 删除的条目数
     */
    public long removeByIds(@NonNull Collection<P> ids) {
        if (ids.isEmpty()) {
            log.warn("删除的主键列表为空");
            return 0;
        }
        Query query = new Query();
        query.addCriteria(Criteria.where(MONGO_DB_ID_NAME).in(Set.copyOf(CollUtil.removeNull(ids))));
        long deletedCount = mongoTemplate.remove(query, getEnityClass()).getDeletedCount();
        log.debug("删除了{}条记录", deletedCount);
        return deletedCount;
    }

    /**
     * 根据文档列表中的主键进行批量删除
     *
     * @param docs 文档列表，必须包含主键ID
     * @return 删除的条目数
     */
    public long removeByDocIds(@NonNull Collection<T> docs) {
        Collection<P> ids = getIds(docs);
        if (log.isDebugEnabled()) {
            log.debug("删除对象的主键列表: {}", JSONUtil.toJsonStr(ids));
        }
        if (ids.isEmpty()) {
            log.warn("删除对象主键为空");
            return 0;
        }
        return removeByIds(ids);
    }

    /**
     * 根据文档列表中的主键列表，过滤null值
     *
     * @param docs 文档列表，必须包含主键ID
     * @param <T>  文档类型
     * @param <P>  主键类型
     * @return 主键列表
     */
    @NonNull
    private static <T extends MongoDbEo<P>, P extends Serializable> Set<P> getIds(@NonNull Collection<T> docs) {
        return docs.parallelStream().map(MongoDbEo::getId).filter(Objects::nonNull).collect(Collectors.toSet());
    }

    /**
     * 根据 ID 查询
     *
     * @param id 主键ID
     */
    @Nullable
    public T findById(P id) {
        return mongoTemplate.findById(id, getEnityClass());
    }

    /**
     * 根据 ID列表 查询
     *
     * @param ids 主键列表
     */
    @NonNull
    public List<T> findByIds(Collection<P> ids) {
        if (CollUtil.isEmpty(ids)) {
            return Collections.emptyList();
        }
        return CollUtil.emptyIfNull(mongoTemplate.find(getInPksQuery(ids), getEnityClass()));
    }

    /**
     * 统计匹配主键的条目数
     *
     * @param ids 主键列表
     * @return 条目数
     */
    public long countByIds(Collection<P> ids) {
        if (CollUtil.isEmpty(ids)) {
            return 0L;
        }
        return mongoTemplate.count(getInPksQuery(ids), getEnityClass());
    }

    /**
     * 根据实例条件进行统计
     *
     * @param example 查询实例
     * @return 条目数
     */
    public long count(T example) {
        Query query = buildQuery(example, getEnityClass());
        return countByQuery(query);
    }

    /**
     * 根据查询条件进行统计
     *
     * @param query 查询条件
     * @return 条目数
     */
    public long countByQuery(Query query) {
        return countByQuery(query, getEnityClass());
    }

    /**
     * 根据实例条件进行统计
     *
     * @param query       查询条件
     * @param entityClass 实体类
     * @return 条目数
     */
    protected long countByQuery(Query query, Class<T> entityClass) {
        if (isEmptyCriteria(query)) {
            return getMongoTemplate().estimatedCount(entityClass);
        } else {
            return getMongoTemplate().count(query, entityClass);
        }
    }

    /**
     * 判断查询条件是否为空
     *
     * @param query 查询条件
     * @return 是否为空
     */
    public static boolean isEmptyCriteria(@NonNull Query query) {
        return query.getQueryObject().isEmpty();
    }

    /**
     * 判断排序条件是否为空
     *
     * @param query 查询条件
     * @return 是否为空
     */
    public static boolean isEmptySort(@NonNull Query query) {
        return query.getSortObject().isEmpty();
    }

    /**
     * 根据实例条件进行删除
     *
     * @param example 查询实例
     * @return 条目数
     */
    public long remove(T example) {
        Assert.isFalse(ObjectUtil.isEmpty(example), "删除实例不能为空");
        Query query = buildQuery(example, getEnityClass());
        return mongoTemplate.remove(query, getEnityClass()).getDeletedCount();
    }

    /**
     * 根据 ID 更新
     *
     * @param id            主键ID
     * @param updateExample 更新内容的键值实例
     * @return 是否更新成功
     */
    public boolean updateById(P id, T updateExample) {
        Optional<UpdateDefinition> updateDefinitionOpt = buildUpdate(updateExample);

        if (Boolean.FALSE == updateDefinitionOpt.isPresent()) {
            log.debug("更新对象为空");
            return false;
        }

        UpdateResult updateResult = mongoTemplate.updateFirst(getPkQuery(id), updateDefinitionOpt.get(), getEnityClass());
        if (log.isDebugEnabled()) {
            log.debug("更新了{}条记录", updateResult.getMatchedCount());
        }
        return updateResult.getMatchedCount() > 0;
    }

    /**
     * 根据 IDs 批量更新，非事务性，不支持回滚
     *
     * @param ids           主键ID集合
     * @param updateExample 更新内容的键值实例
     * @return 是否更新成功
     */
    public boolean updateByIds(Collection<P> ids, T updateExample) {
        Optional<UpdateDefinition> updateDefinitionOpt = buildUpdate(updateExample);

        if (Boolean.FALSE == updateDefinitionOpt.isPresent()) {
            log.debug("更新对象为空");
            return false;
        }

        UpdateResult updateResult = mongoTemplate.updateMulti(getInPksQuery(ids), updateDefinitionOpt.get(), getEnityClass());
        if (log.isDebugEnabled()) {
            log.debug("更新了{}条记录", updateResult.getMatchedCount());
        }
        return updateResult.getMatchedCount() > 0;
    }

    /**
     * 根据搜索条件进行更新
     *
     * @param queryCriteria 查询条件
     * @param updateExample 更新内容的键值实例
     * @return 是否更新成功
     */
    public boolean update(Criteria queryCriteria, T updateExample) {

        Optional<UpdateDefinition> updateDefinitionOpt = buildUpdate(updateExample);

        if (Boolean.FALSE == updateDefinitionOpt.isPresent()) {
            log.debug("更新对象为空");
            return false;
        }

        UpdateResult updateResult = mongoTemplate.updateFirst(new Query().addCriteria(queryCriteria), updateDefinitionOpt.get(), getEnityClass());
        if (log.isDebugEnabled()) {
            log.debug("更新了{}条记录", updateResult.getMatchedCount());
        }
        return updateResult.getMatchedCount() > 0;
    }

    /**
     * 根据搜索条件进行更新
     *
     * @param query            查询条件
     * @param updateDefinition 更新内容的键值实例
     * @return 是否更新成功
     */
    public boolean update(Query query, UpdateDefinition updateDefinition) {

        if (Objects.isNull(updateDefinition)) {
            log.debug("更新对象为空");
            return false;
        }

        UpdateResult updateResult = mongoTemplate.updateMulti(query, updateDefinition, getEnityClass());
        if (log.isDebugEnabled()) {
            log.debug("更新了{}条记录", updateResult.getMatchedCount());
        }
        return updateResult.getMatchedCount() > 0;
    }

    /**
     * 根据搜索条件进行更新
     *
     * @param query               查询条件
     * @param updateDefinitionMap 更新内容的键值实例
     * @return 是否更新成功
     */
    public boolean update(Query query, Map<String, Object> updateDefinitionMap) {

        if (MapUtil.isEmpty(updateDefinitionMap)) {
            log.debug("更新对象为空");
            return false;
        }

        Update updateDefinition = new Update();

        updateDefinitionMap.entrySet().parallelStream().forEach(entry -> {
            // 主键不更新
            if (PK_FIELD_NAME.equals(entry.getKey())) {
                log.debug("主键不能更新，跳过");
                return;
            }
            updateDefinition.set(entry.getKey(), entry.getValue());
        });

        UpdateResult updateResult = mongoTemplate.updateMulti(query, updateDefinition, getEnityClass());
        if (log.isDebugEnabled()) {
            log.debug("更新了{}条记录", updateResult.getMatchedCount());
        }
        return updateResult.getMatchedCount() > 0;
    }

    public static <T extends MongoDbEo<P>, P extends Serializable> Optional<UpdateDefinition> buildUpdate(T updateExample) {
        Map<String, Object> updateMap = BeanUtil.beanToMap(updateExample, false, true);
        if (updateMap.isEmpty()) {
            return Optional.empty();
        }
        Update updateDefinition = new Update();
        updateMap.entrySet().parallelStream().forEach(entry -> {
            // 主键不更新
            if (PK_FIELD_NAME.equals(entry.getKey())) {
                log.debug("主键不能更新，跳过");
                return;
            }
            updateDefinition.set(entry.getKey(), entry.getValue());
        });
        return Optional.of(updateDefinition);
    }

    /**
     * 根据实例查询
     *
     * @param example 查询实例
     * @return 查询结果
     */
    @NonNull
    public List<T> findAllByExample(@NonNull Object example) {
        return findAllByExample(example, null);
    }

    /**
     * 根据实例查询
     *
     * @param query 查询条件
     * @return 查询结果
     */
    @NonNull
    public List<T> findAll(@NonNull Query query) {
        return CollUtil.emptyIfNull(mongoTemplate.find(query, getEnityClass()));
    }

    /**
     * 根据实例查询
     *
     * @param example 查询实例
     * @return 查询结果
     */
    @NonNull
    public List<T> findAllByExample(@NonNull Object example, @Nullable ExtQueryConsumerFunc extQueryConsumerFunc) {

        Assert.isFalse(BeanUtil.isEmpty(example), "查询实例不能为空");

        // 创建查询
        Query query = buildQuery(example, getEnityClass());

        if (Objects.nonNull(extQueryConsumerFunc)) {
            extQueryConsumerFunc.accept(query);
        }

        List<T> resultData = this.findAll(query);
        if (log.isDebugEnabled()) {
            log.debug("查询结果: {}", JSONUtil.toJsonStr(resultData));
        }
        return resultData;
    }

    /**
     * 根据实例查询，返回流
     *
     * @param example 查询实例
     * @return 结果流
     */
    @NonNull
    public Stream<T> findByExampleStream(@NonNull Object example, @Nullable ExtQueryConsumerFunc extQueryConsumerFunc) {

        if (Boolean.FALSE == example instanceof Map) {
            Assert.isFalse(BeanUtil.isEmpty(example), "查询实例不能为空");
        }

        // 创建查询
        Query query = buildQuery(example, getEnityClass());

        if (Objects.nonNull(extQueryConsumerFunc)) {
            extQueryConsumerFunc.accept(query);
        }
        return mongoTemplate.stream(query, getEnityClass());
    }

    /**
     * 根据实例查询，返回流
     *
     * @param example 查询实例
     * @return 结果流
     */
    @NonNull
    public Stream<T> findByExampleStream(@NonNull Object example) {
        return findByExampleStream(example, null);
    }

    /**
     * 获取根据 ID 查询的Query
     *
     * @param id  主键ID
     * @param <P> 主键类型
     * @return Query
     */
    @NonNull
    private static <P extends Serializable> Query getPkQuery(@NonNull P id) {
        return new Query(Criteria.where(MONGO_DB_ID_NAME).is(id));
    }

    /**
     * 获取根据 ID 集合查询的Query
     *
     * @param ids 主键ID集合
     * @param <P> 主键类型
     * @return Query
     */
    @NonNull
    private static <P extends Serializable> Query getInPksQuery(@NonNull Collection<P> ids) {
        return new Query(Criteria.where(MONGO_DB_ID_NAME).in(Set.copyOf(CollUtil.removeNull(ids))));
    }

    /**
     * 根据 ID 查询
     *
     * @param id 主键ID
     */
    public Optional<T> findOptById(P id) {
        return Optional.ofNullable(mongoTemplate.findById(id, getEnityClass()));
    }

    /**
     * 根据查询实例查询1个结果
     *
     * @param example 查询实例
     * @return 查询结果
     */
    @Nullable
    public T findOne(T example) {
        Query query = buildQuery(example, getEnityClass());
        return mongoTemplate.findOne(query, getEnityClass());
    }

    /**
     * 根据查询实例查询1个结果
     *
     * @param example 查询实例
     * @return 查询结果
     */
    public Optional<T> findOneOpt(T example) {
        return Optional.ofNullable(findOne(example));
    }

    /**
     * 查询任意一条记录
     *
     * @return 查询结果
     */
    public Optional<T> findAnyOne() {
        return Optional.ofNullable(mongoTemplate.findOne(new Query(), getEnityClass()));
    }

    /**
     * 查询所有
     *
     * @return 查询结果流
     */
    @NonNull
    public Stream<T> findAllStream() {
        return mongoTemplate.stream(new Query(), getEnityClass());
    }

    /**
     * 查询所有
     *
     * @param query 查询条件
     * @return 查询结果流
     */
    @NonNull
    public Stream<T> findAllStream(@NonNull Query query) {
        return mongoTemplate.stream(query, getEnityClass());
    }

    /**
     * 根据实例查询所有
     *
     * @param example 查询实例
     * @return 查询结果流
     */
    @NonNull
    public Stream<T> findAllStreamByExample(@NonNull Object example) {
        return mongoTemplate.stream(buildQuery(example, getEnityClass()), getEnityClass());
    }

    public MongoTemplate getMongoTemplate() {
        return mongoTemplate;
    }

    /**
     * 获取实体类
     *
     * @return 实体类
     */
    protected abstract Class<T> getEnityClass();

    @Resource
    public void setMongoTemplate(MongoTemplate mongoTemplate) {
        this.mongoTemplate = mongoTemplate;
    }
}
